package org.xukai.spark.streaming.scala

import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.streaming.{Seconds, StreamingContext}

/**
 * Spark Streaming word count that keeps a cumulative per-word total using
 * `updateStateByKey`.
 *
 * Text lines are read from a socket, split on single spaces, and turned into
 * `(word, 1)` pairs. Two stateful streams are derived from those pairs:
 *
 *  - `wordCounts` uses the simple `(Seq[V], Option[S]) => Option[S]` overload;
 *  - `stateDstream` uses the iterator-based overload with an explicit
 *    partitioner and an initial state RDD.
 *
 * Both streams are printed every batch.
 *
 * @author Administrator
 */
object UpdateStateByKeyWordCount {

  /**
   * Builds the streaming job, starts it and blocks until it terminates.
   *
   * @param args command-line arguments (currently unused)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("UpdateStateByKeyWordCount")
    // Batches are produced every 5 seconds.
    val ssc = new StreamingContext(conf, Seconds(5))
    // Stateful transformations such as updateStateByKey need a checkpoint directory.
    ssc.checkpoint("hdfs://spark1:9000/wordcount_checkpoint")

    // Read lines from the socket and convert them into (word, 1) pairs.
    val lines = ssc.socketTextStream("spark1", 9999)
    val words = lines.flatMap(_.split(" "))
    val pairs = words.map(word => (word, 1))

    // Single update rule shared by both stateful streams:
    // new total = sum of this batch's values + previous total (0 if the key is new).
    val updateFunc: (Seq[Int], Option[Int]) => Option[Int] =
      (values, state) => Some(values.sum + state.getOrElse(0))

    // Variant 1: per-key update function, default partitioning, no initial state.
    val wordCounts = pairs.updateStateByKey[Int](updateFunc)
    wordCounts.print()

    // Adapts updateFunc to the iterator-based overload, which receives
    // (key, new values, previous state) triples and returns (key, new state) pairs.
    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) =>
      iterator.flatMap { case (word, values, state) =>
        updateFunc(values, state).map(total => (word, total))
      }

    // Initial state for variant 2: "hello" and "world" both start with a count of 1.
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

    // Variant 2: iterator-based update with an explicit hash partitioner
    // (rememberPartitioner = true) and the initial state defined above.
    val stateDstream = pairs.updateStateByKey[Int](newUpdateFunc,
      new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
    // Register an output operation; without one this stream would never be computed.
    stateDstream.print()

    ssc.start()
    ssc.awaitTermination()
  }

}